gpt_bigcode: added FusedSDPA kernel #1138
Conversation
Force-pushed from 80da185 to 416ad8a
(Screenshots: generation stats for the original implementation vs. FusedSDPA)
Review thread on optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (outdated, resolved)
@mgonchar, thank you for your PR.
- Please run `make style` prior to future submissions; it takes care of code formatting fixes.
- Please verify that bigcode/starcoder, which also uses this model file, runs correctly for you. I am seeing the error `RuntimeError: Graph compile failed. synStatus=synStatus 26 [Generic failure].` for `python run_generation.py --model_name_or_path bigcode/starcoder --batch_size 1 --use_hpu_graphs --use_kv_cache --max_new_tokens 100 --bf16 --use_flash_attention`. I suspect this has to do with the "starcoderbase" check: it works if that is substituted with "starcoder", but we need a better match, perhaps because we want to avoid starcoder2 (see the sketch after this comment).
- Testing: please add tests for starcoder and starcoderbase with the flash attention options in tests/test_text_generation_example.py.

I will do another pass after we resolve these.
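A minimal sketch of the kind of check being discussed, assuming the intent is to match both starcoder and starcoderbase checkpoints while rejecting starcoder2 (a different architecture); the helper name and matching logic are hypothetical, not the PR's actual code:

```python
# Hypothetical helper illustrating the naming pitfall: "starcoder2" contains
# "starcoder" as a substring, so it has to be excluded explicitly before the
# broader match.
def uses_gpt_bigcode_flash_attention(model_name_or_path: str) -> bool:
    name = model_name_or_path.lower()
    if "starcoder2" in name:
        return False  # starcoder2 should not take the gpt_bigcode path
    return "starcoder" in name  # matches "starcoder" and "starcoderbase"
```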
Force-pushed from eb41c76 to 815c896
@vidyasiv I've updated this PR based on your feedback. Please have a look.
Minor typo. I was able to run tests, and so far LGTM.
Review thread on optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (outdated, resolved)
Force-pushed from 815c896 to 4fbb0bb
Please re-run `make style`.
Force-pushed from 4fbb0bb to 41edbba
@vidyasiv done
@regisss, please take a look.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
- Added support for the following options to gpt_bigcode (the starcoder class of models): `use_flash_attention`, `flash_attention_recompute`, `flash_attention_fast_softmax`, `flash_attention_causal_mask`
- Updated the test for the starcoder model
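As a hedged usage sketch, here is how these options might be passed at generation time, assuming they follow the generation-kwarg pattern used by other optimum-habana models (setup is abbreviated; vanilla transformers without the Gaudi adaptation would reject the extra kwargs):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Abbreviated setup; on Gaudi the model would first be adapted via
# optimum-habana (e.g. through run_generation.py) and moved to HPU.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    use_flash_attention=True,            # route attention through FusedSDPA
    flash_attention_recompute=True,      # trade extra compute for lower memory
    flash_attention_fast_softmax=True,   # use the fast softmax variant
    flash_attention_causal_mask=True,    # let the kernel apply causal masking
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```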
Force-pushed from 41edbba to 7aa7a8e
PR rebased. I've rechecked the rebased code and found no regressions. @regisss, please have a look.
LGTM!
Added support for the following options to the gpt_bigcode (starcoderbase) model:
- `use_flash_attention`
- `flash_attention_recompute`
- `flash_attention_fast_softmax`
- `flash_attention_causal_mask`
Before submitting